from prodigyopt.prodigy import Prodigy as ProdigyOfficialImpl

class Prodigy(ProdigyOfficialImpl):
    """Thin wrapper around the official ``prodigyopt`` Prodigy optimizer.

    Construction is delegated to the upstream implementation. This subclass
    only adds two methods, :meth:`has_d_estimator` and
    :meth:`calculate_d_estimation_error`, which report on the ``d_max`` value
    the upstream optimizer stores in its param groups.
    """

    def __init__(self, params, lr=1.0, betas=(0.9, 0.999), beta3=None, eps=1e-8, weight_decay=0, decouple=True,
                 use_bias_correction=False, safeguard_warmup=False, d0=1e-6, d_coef=1.0, growth_rate=float('inf'),
                 fsdp_in_use=False, **kwargs):
        """Create the optimizer.

        Every named argument has the same meaning and default as in
        ``prodigyopt.Prodigy`` and is forwarded to it unchanged. Any extra
        keyword arguments are passed through as-is. This lets options added
        by newer ``prodigyopt`` releases be used without editing this wrapper.
        Unknown names still raise ``TypeError`` from the upstream constructor.
        """
        # Forward by keyword rather than by position. With a positional call,
        # any later argument is silently misassigned if upstream ever inserts
        # or reorders a parameter.
        super().__init__(
            params,
            lr=lr,
            betas=betas,
            beta3=beta3,
            eps=eps,
            weight_decay=weight_decay,
            decouple=decouple,
            use_bias_correction=use_bias_correction,
            safeguard_warmup=safeguard_warmup,
            d0=d0,
            d_coef=d_coef,
            growth_rate=growth_rate,
            fsdp_in_use=fsdp_in_use,
            **kwargs,
        )

    def has_d_estimator(self):
        """Return ``True``: this optimizer supports :meth:`calculate_d_estimation_error`.

        NOTE(review): presumably queried by the surrounding training code to
        decide whether to call :meth:`calculate_d_estimation_error`. Confirm
        against the callers.
        """
        return True

    def calculate_d_estimation_error(self, actual_d):
        """Return ``actual_d`` divided by the ``'d_max'`` entry of the first param group.

        Args:
            actual_d: Reference value to compare against. Presumably a
                known-good step size supplied by the caller; confirm against
                the call sites.

        Returns:
            The ratio ``actual_d / d_max``. A value of 1.0 means the two agree.
        """
        # Only the first param group is consulted.
        # NOTE(review): this assumes upstream keeps 'd_max' identical across
        # all groups. Confirm this for the installed prodigyopt version.
        return actual_d / self.param_groups[0]['d_max']
